import numpy as np
import matplotlib.pyplot as plt
import os
from matplotlib.colors import LogNorm
from dataclasses import dataclass

# Archive containing the Fig. S4 data; a relative path, so it is opened
# relative to the current working directory.
fname = "figS4.npz"

@dataclass(init=True, repr=True, eq=True)
class line_package:
    """A single Rxx line cut as stored in the npz archive.

    The plotting code draws ``R_cut + R_shift`` against ``nu`` using
    ``color`` and ``label``.  The lower-case class name is kept on purpose:
    objects loaded with ``allow_pickle=True`` presumably resolve their class
    by this exact module-level name, so renaming it would break loading.
    """

    # x-axis values of the cut; annotated as float, but presumably an array
    # in the stored data — TODO confirm.
    nu: float = 0
    # resistance values of the cut (y-axis before the offset is applied)
    R_cut: float = 0
    # n_cut and ns are not read by the visible plotting code
    n_cut: float = 0
    ns: float = 0
    # matplotlib color spec for the trace
    color: str = 'k'
    # legend entry for the trace
    label: str = ''
    # vertical offset added to R_cut when plotting
    R_shift: float = 0

@dataclass
class Dataset:
    """Metadata for one measured gate map, as stored in the npz archive.

    Instances are loaded from the archive rather than built in this script.
    The plotting code also reads attributes that are not declared here
    (``n``, ``D``, ``Vtg_arr``, ``Vbg_arr``, ``data_arr``, ``Iac``,
    ``Rmult``); presumably they were attached to the instances before
    pickling — TODO confirm.
    """

    device: object = object()  # presumably a Device instance — TODO confirm
    data_dir: str = ''
    # NOTE(review): `iter` is the builtin function, not a type; the default is
    # a range, so presumably any iterable of file numbers is accepted.
    fnums: iter = range(1)
    fn_npz: str = ''
    nx: int = 0
    ny: int = 0
    # Column indices; R_col selects the resistance signal via
    # data_arr[:, :, R_col] in the plotting code.
    Vbg_col: int = 0
    Vtg_col: int = 0
    R_col: int = 0
    n0: float = 0
    D0: float = 0
    B: float = 0
    npz_dir: str = '.'
    Vg: float = 0
    # A proper annotated field (a chained `B_col = int = 0` would make it a
    # plain class attribute and shadow `int` inside the class body).
    # Declared last so positional construction of the fields above keeps
    # its original argument order.
    B_col: int = 0

@dataclass(init=True, repr=True, eq=True)
class Device:
    """Parameters of a measured device; all four fields are required.

    Not used directly by this script; presumably defined so that device
    objects stored inside pickled datasets can be unpickled — TODO confirm.
    Units and exact meaning of the numeric fields are not visible here.
    """

    dtg: float  # presumably top-gate dielectric thickness — TODO confirm
    dbg: float  # presumably bottom-gate dielectric thickness — TODO confirm
    ep: float  # presumably a dielectric constant — TODO confirm
    name: str

# load npz file
# allow_pickle=True is needed because the archive stores Python objects
# (presumably line_package / Dataset instances, whose classes are defined
# above). Unpickling can execute arbitrary code: only load trusted files.
with np.load(fname, allow_pickle=True) as data:
    print(data.files)
    # NpzFile reads members lazily, so index them before the `with` block
    # closes the underlying archive.
    line_packages = data['line_packages']
    datasets = data['datasets']

# plot line cuts of Rxx vs. nu
fig, ax = plt.subplots(1, 1, figsize=(6, 6))
# The loop variable must not be named `line_package`: that would rebind the
# module-level dataclass, and pickle resolves classes by module-level name,
# so any line_package objects unpickled afterwards would fail to load.
for pkg in line_packages:
    # R_shift is added to each trace, presumably to offset the curves
    # vertically so they do not overlap.
    ax.plot(pkg.nu, pkg.R_cut + pkg.R_shift, label=pkg.label, color=pkg.color)
ax.set_xlabel(r'$\nu$')
ax.set_ylabel(r'$R_{xx}$')
ax.legend()
ax.set_xlim([-4, 4])
fig.suptitle('Fig. S4')
plt.show()

# plot Rxx(n,D) or Rxx(Vtg,Vbg) maps of all the different datasets
plot_nD = True  # True: plot vs. (n, D); False: plot vs. (Vtg, Vbg)
device_names = ['A', 'B', 'C', 'E', 'D3', 'D1', 'D2', 'F']


def _orient_like(values, ref):
    """Return *values* transposed when that makes its shape match *ref*'s.

    ``Axes.scatter`` ravels ``c`` and only checks that its size equals the
    number of points, so a 2D color array whose shape is the transpose of
    the coordinate arrays is accepted silently but mapped to the wrong
    points.  Transposing cannot fix a size mismatch, so any other shape is
    returned unchanged and left for ``scatter`` to reject.
    """
    shape, ref_shape = np.shape(values), np.shape(ref)
    if shape != ref_shape and shape[::-1] == ref_shape:
        # np.transpose keeps masked arrays masked (unlike np.asarray(...).T).
        return np.transpose(values)
    return values


for idx, dataset in enumerate(datasets):
    # Selected data column divided by dataset.Iac and scaled by
    # dataset.Rmult (presumably signal / excitation current -> resistance).
    Rxx = dataset.data_arr[:, :, dataset.R_col] / dataset.Iac * dataset.Rmult

    if plot_nD:
        x, y = dataset.n, dataset.D
        xlabel = r'$n$ ($10^{12}$ cm$^{-2}$)'
        ylabel = r'$D/\epsilon_0$ (V/nm)'
    else:
        x, y = dataset.Vtg_arr, dataset.Vbg_arr
        xlabel = r'$V_{tg}$ (V)'
        ylabel = r'$V_{bg}$ (V)'

    fig, ax = plt.subplots(1, 1, figsize=(6, 6))
    im = ax.scatter(x, y, c=_orient_like(Rxx, x), marker='s', s=3, norm=LogNorm())
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    # Fall back to the index when there are more datasets than known names.
    device_label = device_names[idx] if idx < len(device_names) else str(idx)
    fig.suptitle('Device ' + device_label)
    fig.colorbar(im, ax=ax)
    plt.show()
